{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "959300d4",
   "metadata": {},
   "source": [
    "# MLX Local Pipelines\n",
    "\n",
    "MLX models can be run locally through the `MLXPipeline` class.\n",
    "\n",
    "The [MLX Community](https://huggingface.co/mlx-community) hosts over 150 models, all open source and publicly available on the Hugging Face Model Hub, an online platform where people can easily collaborate and build ML together.\n",
    "\n",
    "These models can be called from LangChain through this local pipeline wrapper, the `MLXPipeline` class. For more information on MLX, see the [examples repo](https://github.com/ml-explore/mlx-examples/tree/main/llms)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1b8450-5eaf-4d34-8341-2d785448a1ff",
   "metadata": {
    "tags": []
   },
   "source": [
    "To use, you should have the `mlx-lm` Python [package installed](https://pypi.org/project/mlx-lm/), as well as [transformers](https://pypi.org/project/transformers/). You can optionally install `huggingface_hub` as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d772b637-de00-4663-bd77-9bc96d798db2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install --upgrade --quiet  mlx-lm transformers huggingface_hub"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91ad075f-71d5-4bc8-ab91-cc0ad5ef16bb",
   "metadata": {},
   "source": [
    "### Model Loading\n",
    "\n",
    "Models can be loaded with the `from_model_id` method by passing a model ID and, optionally, generation parameters through `pipeline_kwargs`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165ae236-962a-4763-8052-c4836d78a5d2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain_community.llms.mlx_pipeline import MLXPipeline\n",
    "\n",
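    "# Load the model by its Hugging Face model ID.\n",
    "# `pipeline_kwargs` holds generation options: a token limit and a sampling temperature.\n",
    "pipe = MLXPipeline.from_model_id(\n",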
    "    \"mlx-community/quantized-gemma-2b-it\",\n",
    "    pipeline_kwargs={\"max_tokens\": 10, \"temp\": 0.1},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00104b27-0c15-4a97-b198-4512337ee211",
   "metadata": {},
   "source": [
    "They can also be loaded by passing in an existing model and tokenizer, as returned by `mlx_lm.load`, directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f426a4f",
   "metadata": {},
   "outputs": [],
   "source": [
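    "from langchain_community.llms.mlx_pipeline import MLXPipeline\n",
    "from mlx_lm import load\n",
    "\n",
    "# Load the model and tokenizer with mlx-lm, then wrap them in the LangChain pipeline.\n",
    "model, tokenizer = load(\"mlx-community/quantized-gemma-2b-it\")\n",
    "pipe = MLXPipeline(model=model, tokenizer=tokenizer)"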
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60e7ba8d",
   "metadata": {},
   "source": [
    "### Create Chain\n",
    "\n",
    "With the model loaded into memory, you can compose it with a prompt to\n",
    "form a chain."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3acf0069",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import PromptTemplate\n",
    "\n",
    "template = \"\"\"Question: {question}\n",
    "\n",
    "Answer: Let's think step by step.\"\"\"\n",
    "prompt = PromptTemplate.from_template(template)\n",
    "\n",
    "chain = prompt | pipe\n",
    "\n",
    "question = \"What is electroencephalography?\"\n",
    "\n",
    "print(chain.invoke({\"question\": question}))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
